import torch
import torch.nn as nn

class LinearApproximator(nn.Module):
    """Learned affine map ``y = x @ W + b`` over the last dimension.

    ``W`` is square (``input_size x input_size``), so the output has the same
    feature size as the input. Both ``W`` and ``b`` are drawn from an unscaled
    standard normal (``torch.randn``) at construction time.

    Args:
        input_size: Size of the last dimension of the input (and output).
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        # Draw weight before bias so the RNG stream (and state_dict order) is W, b.
        weight = torch.randn(input_size, input_size)
        bias = torch.randn(input_size)
        self.W = nn.Parameter(weight)
        self.b = nn.Parameter(bias)

    def forward(self, x):
        """Apply the affine map; ``x`` is ``[..., input_size]`` (e.g. ``[B, input_size]``)."""
        projected = x @ self.W  # row-vector convention: W[i, j] maps input i -> output j
        return projected + self.b

class ConvApproximator(nn.Module):
    """Small strided-conv autoencoder: two stride-2 convs down, two stride-2 transposed convs up.

    The encoder reduces each spatial dim to ``ceil(H / 4)``. The decoder exactly
    doubles it twice, giving ``4 * ceil(H / 4) >= H``. ``forward`` crops the
    result back to the input's spatial size, so the output always matches the
    input shape, even when H or W is not a multiple of 4.

    Args:
        input_size: Stored as ``self.input_size``; not used by any layer.
        feature_dim: Channel count at the bottleneck.
        in_channels: Image channels in and out (default 3, the previously hard-coded value).
    """

    def __init__(self, input_size, feature_dim, in_channels=3):
        super(ConvApproximator, self).__init__()
        self.input_size = input_size

        # Encoder (downsampling): each conv maps a spatial size s -> ceil(s / 2).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),  # ceil(H/2)
            nn.ReLU(inplace=True),
            nn.Conv2d(32, feature_dim, kernel_size=3, stride=2, padding=1),  # ceil(H/4)
            nn.ReLU(inplace=True),
        )

        # Decoder (upsampling): (s - 1) * 2 - 2 * 1 + 4 = 2 * s, i.e. exact doubling.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 32, kernel_size=4, stride=2, padding=1),  # x2
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, in_channels, kernel_size=4, stride=2, padding=1),  # x2
            nn.Sigmoid(),  # outputs in [0, 1]
        )

    def forward(self, x):
        """Reconstruct ``x`` (``[B, in_channels, H, W]``); output has the same shape, values in [0, 1]."""
        height, width = x.shape[-2:]
        decoded = self.decoder(self.encoder(x))
        # When H or W is not divisible by 4 the decoder overshoots (4 * ceil(H/4) > H);
        # crop back so the reconstruction lines up with the input. No-op otherwise.
        return decoded[..., :height, :width]

class UNetAutoEncoder(nn.Module):
    """Two-level U-Net style autoencoder with skip connections.

    The encoder halves the resolution twice with ``MaxPool2d(2)``, which floors.
    The decoder doubles it twice with ``ConvTranspose2d(kernel_size=2, stride=2)``.
    When H or W is not a multiple of 4, each upsampled tensor is zero-padded to
    its skip tensor's size before concatenation, so any spatial size >= 4 works
    and the output always matches the input's spatial size.

    Args:
        in_channels: Channels of the input and of the reconstructed output.
        feature_dim: Channels after the second encoder stage; the bottleneck uses twice this.
    """

    def __init__(self, in_channels=3, feature_dim=64):
        super().__init__()

        # Encoder (downsampling)
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(2)  # floor(H/2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(32, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(2)  # floor(floor(H/2)/2)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim * 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # Decoder (upsampling with skip connections)
        self.up1 = nn.ConvTranspose2d(feature_dim * 2, feature_dim, kernel_size=2, stride=2)  # x2
        self.dec1 = nn.Sequential(
            # input = upsampled (feature_dim) + skip x2 (feature_dim)
            nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.up2 = nn.ConvTranspose2d(feature_dim, 32, kernel_size=2, stride=2)  # x2
        self.dec2 = nn.Sequential(
            # input = upsampled (32) + skip x1 (32)
            nn.Conv2d(32 + 32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # Output
        self.output_layer = nn.Conv2d(32, in_channels, kernel_size=1)
        self.activation = nn.Sigmoid()

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` so its last two dims equal ``ref``'s (pooling floors, upsampling doubles)."""
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # Pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        )

    def forward(self, x):
        """Reconstruct ``x`` (``[B, in_channels, H, W]``, H and W >= 4); returns the same shape in [0, 1]."""
        # Encoder
        x1 = self.enc1(x)                     # (B, 32, H, W)
        x2 = self.enc2(self.pool1(x1))        # (B, feature_dim, H/2, W/2)
        x3 = self.bottleneck(self.pool2(x2))  # (B, feature_dim*2, H/4, W/4)

        # Decoder with skip connections
        x = self._match_size(self.up1(x3), x2)  # upsample to x2's size
        x = self.dec1(torch.cat([x, x2], dim=1))

        x = self._match_size(self.up2(x), x1)   # upsample to x1's size
        x = self.dec2(torch.cat([x, x1], dim=1))

        return self.activation(self.output_layer(x))

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A ``Conv2d`` whose kernel size equals its stride (``patch_size``) applies
    the same linear projection to every patch independently.

    Args:
        img_size: Side length the model is configured for; only used to compute ``n_patches``.
        patch_size: Side length of each square patch.
        in_channels: Input image channels.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size=128, patch_size=16, in_channels=3, embed_dim=256):
        super().__init__()
        grid = img_size // patch_size  # patches per side
        self.patch_size = patch_size
        self.n_patches = grid * grid
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``[B, C, H, W]`` to ``[B, (H // P) * (W // P), embed_dim]`` in row-major patch order."""
        feature_map = self.proj(x)                        # [B, embed_dim, H // P, W // P]
        tokens = torch.flatten(feature_map, start_dim=2)  # [B, embed_dim, num_patches]
        return tokens.permute(0, 2, 1)                    # [B, num_patches, embed_dim]


class TransformerAutoEncoder(nn.Module):
    """ViT-style autoencoder: patch embedding -> Transformer encoder -> per-patch linear decoder.

    Each patch token is encoded, then projected back to ``patch_size * patch_size * C``
    pixel values and reassembled into an ``img_size x img_size`` image. The output
    has no final activation, so values are unbounded.

    NOTE(review): this file also defines ``TransformerAutoencoder`` (lower-case "e"),
    a different encoder/decoder model; the two names differ only by case.

    Args:
        img_size: Side length of the (square) input image; must be divisible by ``patch_size``.
        patch_size: Side length of each square patch.
        in_channels: Image channels.
        embed_dim: Token embedding size.
        depth: Number of Transformer encoder layers.
        num_heads: Attention heads per layer.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size=128, patch_size=16, in_channels=3, embed_dim=256, depth=4, num_heads=4):
        super().__init__()
        # Without this check the model builds fine but forward() can never reshape
        # the decoded patches back to img_size x img_size.
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.n_patches
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.img_size = img_size

        # Learned positional encoding, one vector per patch.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

        # Transformer encoder over [B, num_patches, embed_dim] tokens.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Decoder (just a projection back to pixel space)
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, patch_size * patch_size * in_channels),
        )

    def forward(self, x):
        """Reconstruct ``x`` of shape ``[B, C, img_size, img_size]``; returns the same shape.

        Raises:
            ValueError: If the input's spatial size is not ``img_size x img_size``. Such
                inputs would otherwise either fail deep inside the model or, when the
                patch count happens to match, be silently reshaped into a wrong image.
        """
        B, _, H, W = x.shape
        if H != self.img_size or W != self.img_size:
            raise ValueError(
                f"expected input of spatial size {self.img_size}x{self.img_size}, got {H}x{W}"
            )
        grid = self.img_size // self.patch_size  # patches per side

        x = self.patch_embed(x) + self.pos_embed  # [B, num_patches, embed_dim]
        x = self.encoder(x)                       # [B, num_patches, embed_dim]

        # Decode each patch embedding back to pixels.
        x = self.decoder(x)  # [B, num_patches, P*P*C]
        # Read each patch vector as (P_row, P_col, C) laid out on a grid x grid patch grid.
        x = x.view(B, grid, grid, self.patch_size, self.patch_size, -1)

        # [B, grid_r, grid_c, P_r, P_c, C] -> [B, C, grid_r, P_r, grid_c, P_c], so merging
        # (grid_r, P_r) gives image rows and (grid_c, P_c) gives image columns.
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        return x.view(B, -1, self.img_size, self.img_size)
            


class TransformerAutoencoder(nn.Module):
    """Encoder-decoder Transformer autoencoder over flattened image patches.

    Pipeline:
      1. Patchify ``[B, C, H, W]`` into ``[B, N, C*P*P]`` and embed linearly to ``[B, N, D]``.
      2. Prepend a learnable CLS token, add positional encodings, and run the Transformer encoder.
      3. Feed learned per-position query tokens (``decoder_tokens + pos_embed_dec``) through a
         Transformer decoder that cross-attends to the encoder output.
      4. Project back to pixel patches and reassemble the image. There is no output activation.

    Args:
        img_size: Side length of the (square) input image; must be divisible by ``patch_size``.
        patch_size: Side length P of each square patch.
        in_channels: Image channels C.
        embed_dim: Token size D.
        depth: Layers in the encoder and in the decoder.
        num_heads: Attention heads per layer.
        mlp_dim: Feed-forward width inside each Transformer layer.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(
        self,
        img_size=128,
        patch_size=16,
        in_channels=3,
        embed_dim=512,
        depth=6,
        num_heads=8,
        mlp_dim=1024,
    ):
        super().__init__()
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_dim = in_channels * patch_size * patch_size

        # Patch embedding / un-embedding (linear maps between pixel patches and tokens)
        self.patch_embed = nn.Linear(self.patch_dim, embed_dim)
        self.patch_unembed = nn.Linear(embed_dim, self.patch_dim)

        # Positional encodings (encoder has one extra slot for the CLS token)
        self.pos_embed_enc = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
        self.pos_embed_dec = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

        # Learnable CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Learnable decoder query tokens, one per output patch
        self.decoder_tokens = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))

        # batch_first=True because every tensor in forward() is [B, seq, D]. With the
        # default (False) the layers attend across the batch axis instead of across
        # patches, and the decoder's cross-attention fails on the N vs N+1 length mismatch.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_dim, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=depth)

    def _patchify(self, images):
        """Turn ``[B, C, H, W]`` into ``[B, N, C*P*P]``, patches in row-major grid order."""
        B, C = images.shape[:2]
        P = self.patch_size
        patches = images.unfold(2, P, P).unfold(3, P, P)     # [B, C, H//P, W//P, P, P]
        patches = patches.contiguous().view(B, C, -1, P, P)  # [B, C, N, P, P]
        patches = patches.permute(0, 2, 1, 3, 4)             # [B, N, C, P, P]
        return patches.contiguous().view(B, self.num_patches, -1)

    def _unpatchify(self, patches, channels):
        """Inverse of ``_patchify``: ``[B, N, C*P*P]`` -> ``[B, C, img_size, img_size]``."""
        B = patches.size(0)
        P = self.patch_size
        grid = self.img_size // P
        out = patches.view(B, self.num_patches, channels, P, P)
        out = out.permute(0, 2, 1, 3, 4)                   # [B, C, N, P, P]
        out = out.reshape(B, channels, grid, grid, P, P)   # split N into (grid_r, grid_c)
        out = out.permute(0, 1, 2, 4, 3, 5).contiguous()   # [B, C, grid_r, P_r, grid_c, P_c]
        return out.view(B, channels, self.img_size, self.img_size)

    def forward(self, x):
        """Reconstruct ``x`` of shape ``[B, C, img_size, img_size]``; returns the same shape.

        Raises:
            ValueError: If the input's spatial size is not ``img_size x img_size``.
        """
        B, C, H, W = x.shape
        if H != self.img_size or W != self.img_size:
            raise ValueError(
                f"expected input of spatial size {self.img_size}x{self.img_size}, got {H}x{W}"
            )

        # Embed patches
        tokens = self.patch_embed(self._patchify(x))  # [B, N, D]

        # Add CLS token to encoder input, then positional encoding
        cls_token = self.cls_token.expand(B, -1, -1)  # [B, 1, D]
        tokens = torch.cat([cls_token, tokens], dim=1)  # [B, 1+N, D]
        tokens = tokens + self.pos_embed_enc[:, :tokens.size(1), :]

        # Encode
        memory = self.encoder(tokens)  # [B, 1+N, D]

        # Decoder queries: learned tokens + learned positions (both per patch)
        decoder_input = self.decoder_tokens.expand(B, -1, -1) + self.pos_embed_dec  # [B, N, D]

        # Decode and project back to pixel patches
        decoded = self.decoder(tgt=decoder_input, memory=memory)  # [B, N, D]
        decoded = self.patch_unembed(decoded)                     # [B, N, patch_dim]

        return self._unpatchify(decoded, C)
